#%%
import numpy as np, matplotlib.pyplot as plt, seaborn as sns, numpy as np, sys, os, tqdm, lmfit as lm, scipy.special as ss
import pandas as pd, scipy.stats as stats, scipy.optimize as op, scipy.interpolate as ip, pathlib
%matplotlib inline
# Base matplotlib style. matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8-*' and 3.8 removed the old names, so use whichever is present.
base_style = 'seaborn-whitegrid'
if base_style not in plt.style.available:
    base_style = 'seaborn-v0_8-whitegrid'
plt.style.use(base_style)
sns.set_context('poster')
sns.set_style('ticks')

# Directory two levels above the one containing this file (presumably the
# repository root). NOTE(review): `wd` is not used in this part of the file, and
# `__file__` may be undefined when these cells run in a plain Jupyter kernel.
wd = str(pathlib.Path(os.path.dirname(__file__)).parent.parent)
#%%
# Load Dietrich settling data. header=1 skips the first CSV row (presumably
# series names); pandas de-duplicates the repeated X/Y headers as
# X, Y, X.1, Y.1, ..., X.5, Y.5, which the xf construction below relies on.
df_diet = pd.read_csv('other_shape_data/dietrich_settling_data.csv', header=1)
plt.loglog(df_diet.X, df_diet.Y, 'o')

x_data = np.log10(df_diet.X.values)
y_data = np.log10(df_diet.Y.values)
# Fit only points where BOTH coordinates are finite: blank cells load as NaN and
# log10 of non-positive values gives -inf/NaN; lmfit refuses NaN data by default.
finite = np.isfinite(x_data) & np.isfinite(y_data)


def funco(x, a, b, c, d, e):
    """Quartic polynomial in x (used here with x = log10 X, result = log10 Y)."""
    return a + b*x + c*x**2 + d*x**3 + e*x**4


model = lm.Model(funco)
for pname in 'abcde':
    model.set_param_hint(pname, value=1, vary=True)
params = model.make_params()

results = model.fit(y_data[finite], params, x=x_data[finite])
a, b, c, d, e = (results.params[pname].value for pname in 'abcde')
print(a, b, c, d, e)

x = np.logspace(-1, 11, 100)
plt.plot(x, 10**funco(np.log10(x), a, b, c, d, e), 'r-')


def func(x):
    """Reference quartic in x = log10 X returning log10 Y.

    Fixed published coefficients (presumably Dietrich 1982's fit for spheres --
    confirm); used below to normalise every shape bin.
    """
    return -3.76715 + 1.92944*x - 0.09815*x**2 - 0.00575*x**3 + 0.00056*x**4


plt.plot(x, 10**func(np.log10(x)), 'k--')
# plt.xlim(2e5, 2e10)
# plt.ylim(100, 5e5)

# Build xf: for each raw column pair keep points with X > x_min and add
# Yp = (Y / Y_ref(X))**(1/3), the cube root of the ratio to the reference curve.
# Keys map the pandas column suffix to the xf suffix (presumably Corey shape
# factor bins, e.g. '10_8' -> 1.0-0.8; confirm). Series are aligned on the
# original row index, so unequal-length bins are NaN-padded.
x_min = 2e5
bin_suffix = {'': '10', '.1': '10_8', '.2': '8_6', '.3': '6_4', '.4': '4_2', '.5': '2_0'}
xf_cols = {}
for src, dst in bin_suffix.items():
    keep = df_diet['X' + src] > x_min
    xs = df_diet['X' + src][keep]
    ys = df_diet['Y' + src][keep]
    xf_cols['X_' + dst] = xs
    xf_cols['Y_' + dst] = ys
    xf_cols['Yp_' + dst] = (ys / 10**func(np.log10(xs)))**(1/3)
xf = pd.DataFrame(xf_cols)

                

# %%
# Plotting metadata keyed by two-letter grain code: legend label, colour,
# marker symbol and marker size.
label = dict(
    gb='Spheres', ng='Natural gravel (NG) 1', gg='Rounded chips',
    ls='Rectangular prisms', oc='Faceted ellipsoids',
    dn='Tempered glass', jo='NG 2', mb='NG 3', mk='NG 4', pg='NG 5',
    sf='Shell fragments',
)
color = dict(
    gb=sns.xkcd_rgb['cobalt blue'],
    gg=sns.xkcd_rgb['cerulean'],
    oc=sns.xkcd_rgb['blue green'],
    ng=sns.xkcd_rgb['maize'],
    ls=sns.xkcd_rgb['blood orange'],
    bm='y',
    dn=sns.xkcd_rgb['cyan'],
    jo=sns.xkcd_rgb['lime green'],
    mb=sns.xkcd_rgb['black'],
    mk=sns.xkcd_rgb['pink'],
    pg=sns.xkcd_rgb['dark red'],
    sf=sns.xkcd_rgb['dusty purple'],
)
# marker/size share one key order; the second group reuses the first group's symbols.
_marker_codes = ('gb', 'gg', 'oc', 'ng', 'ls', 'dn', 'jo', 'mb', 'mk', 'pg', 'sf')
marker = dict(zip(_marker_codes, ('o', 'D', '*', 'v', 's', 'o', 'D', '*', 'v', 's', 'o')))
size = dict(zip(_marker_codes, (9, 9, 18, 12, 9, 9, 9, 18, 12, 9, 9)))

grains = ['gb', 'oc', 'gg', 'ng', 'ls']                 # flume grain codes
grains_2 = ['dn', 'jo', 'mb', 'mk', 'pg', 'sf']         # 2022 grain codes
# Plot code -> row label used in the *_2022 tables.
grains_new = dict(dn='tg', jo='ng2', mb='ng3', mk='ng4', pg='ng5', sf='sf')

def _load_table(path):
    """Read `path` as CSV, index it by its unnamed first column and clear the index name."""
    table = pd.read_csv(path)
    table = table.set_index('Unnamed: 0')
    return table.rename_axis(None)


# 2022 extra grain materials, then the original flume grains (same read order as before).
df_shape_2022 = _load_table('1_grain_properties/extra_granular_materials_2022/2_var_shape_2022.csv')
df_SV_2022 = _load_table('1_grain_properties/extra_granular_materials_2022/4_var_settling_velocities_2022.csv')
df_SV_VES_2022 = _load_table('1_grain_properties/extra_granular_materials_2022/5_var_VES_settling_velocity_2022.csv')
df_AoR_2022 = _load_table('1_grain_properties/extra_granular_materials_2022/3_var_ARS_2022.csv')
df_shape = _load_table('1_grain_properties/2_var_shape.csv')
df_SV = _load_table('1_grain_properties/4_var_settling_velocities.csv')
df_SV_VES = _load_table('1_grain_properties/5_var_VES_settling_velocity.csv')
df_AoR = _load_table('1_grain_properties/3_var_ARS.csv')


#%%
plot_xtra_grains_flag = False  # True: also overlay the 2022 grain materials
fig, (ax2, ax1) = plt.subplots(2, 1, figsize=(7, 8))  # ax2 on top, ax1 below
fs = 20  # axis-label font size

# plot background data for axis 1: each Yp column of xf is drawn at its bin's
# nominal Corey shape factor x_val, with a little horizontal jitter so that
# overlapping points stay visible.
x_val = {'Yp_10': 1, 'Yp_10_8': .9, 'Yp_8_6': .7, 'Yp_6_4': .5, 'Yp_4_2': .3, 'Yp_2_0': .1}
fac = 0.03  # jitter width as a fraction of x_val
rng = np.random.default_rng(0)  # fixed seed -> reproducible saved figure
for name, x0 in x_val.items():
    y = xf[name].values
    # Uniform jitter over x0 * (1 +/- fac/2), centred on x0.
    x_jit = x0 * (1 + fac * (rng.random(y.size) - 0.5))
    # Label only the first series so 'Ref. 24' appears once in the legend.
    extra = {'label': 'Ref. 24'} if name == 'Yp_10' else {}
    # 1/Yp**2 is what the y-label below calls the relative drag coefficient.
    ax1.plot(x_jit, 1/y**2, '.', c='k', alpha=.3, **extra)

ax1.set_xlabel(r'Corey shape factor, $S_f = c/\sqrt{ab}$', fontsize=fs)
ax1.set_ylabel('Relative coeff. of drag,\n$C_{D_{settle}}/C_{o}$', fontsize=fs)
ax1.set_xscale('linear')
ax1.set_yscale('log')
ax1.set_xlim(0,1.05)
ax1.set_ylim(2e-1,10)
sns.despine()

# plot background data for axis 2: published angle data from other studies,
# plotted as mu_s = tan(angle) against circularity.
# For the Dai/Robinson and Robinson files the first data row is skipped ([1:]);
# it is presumably a sub-header row, since the remaining values are cast to float.
df_dai = pd.read_csv('other_shape_data/dai_robinson_data.csv', header=0)
x_dai = df_dai['circularity_dai'].values[1:].astype(float) * 0.3 + 0.7  # data adjusted to account for accidentally mislabelled axis when copying data from figure
ang_dai = df_dai['Unnamed: 3'].values[1:].astype(float)
ax2.plot(x_dai, np.tan(np.radians(ang_dai)), 'D', markersize=6, c='k', label='Ref. 21', alpha=0.3)

df_carrigy = pd.read_csv('other_shape_data/carrigy.csv', header=1)
# X (a percentage) is mapped linearly from circularity 1 (a circle) at X = 0 to
# pi/4 (a square) at X = 100; Y is rescaled to an angle in degrees. Both
# mappings presumably undo the axis scaling of the source figure -- confirm.
circ = (df_carrigy.X/100) * np.pi/4 + (1-df_carrigy.X/100)
ang = df_carrigy.Y / 3 + 20
ax2.plot(circ, np.tan(np.radians(ang)), 'o', markersize=4, c='k', label='Ref. 22', alpha=0.3)

df_robinson = pd.read_csv('other_shape_data/robinson_circularity.csv', header=0)
circ_grains = df_robinson['grains'].values[1:].astype(float)
ang_grains = df_robinson['Unnamed: 1'].values[1:].astype(float)
circ_mix = df_robinson['sand_sphere_mix'].values[1:].astype(float)
ang_mix = df_robinson['Unnamed: 3'].values[1:].astype(float)
ax2.plot(circ_grains, np.tan(np.radians(ang_grains)), '*', c='k', label='Ref. 23', alpha=0.3)
ax2.plot(circ_mix, np.tan(np.radians(ang_mix)), '*', c='k', alpha=0.3)

ax2.legend(ncol=1, fontsize=14, loc=3)
ax2.set_xlim(0.7,1.01)
ax2.set_ylim(0.2,1)
ax2.set_xscale('linear')
ax2.set_yscale('linear')
# Raw strings: '\p' and '\m' are not valid Python escapes.
ax2.set_xlabel(r'Circularity, $S_c = 4\pi A/P^2$', fontsize=fs)
ax2.set_ylabel(r'Coeff. of static friction, $\mu_s$', fontsize=fs)
# if plot_xtra_grains_flag is True:
    # ax1.text(1, 9, 'b', fontsize=18)
    # ax2.text(1.0, .95, 'a', fontsize=18)
# else:
    # ax1.text(1, 9, 'd', fontsize=18)
    # ax2.text(1.0, .95, 'c', fontsize=18)
sns.despine()

# plot new grain data (2022)
# Only drawn when plot_xtra_grains_flag is True. Adds the 2022 grain materials
# (plot codes in grains_2, mapped via grains_new to row labels of the *_2022
# tables) to both panels.
if plot_xtra_grains_flag is True:
    for grain_in in grains_2:
        grain = grains_new[grain_in]
        # Uncertainty of the drag ratio (ws_sphere/vel_mean)**2: relative error
        # 2*sqrt((dws/ws)**2 + (dvel/vel)**2), multiplied by the ratio itself.
        # NOTE(review): the extra /np.sqrt(df_SV.vels.size) uses the flume table
        # df_SV rather than df_SV_2022, and the flume-grain error bars on ax1 are
        # not divided by it -- confirm which sample count is intended.
        yerr = 2*((df_SV_VES_2022.ws_sphere[grain]/df_SV_2022.vel_mean[grain])**2)*np.sqrt((df_SV_VES_2022.dws_sphere[grain]/df_SV_VES_2022.ws_sphere[grain])**2 + (df_SV_2022.dvel_mean[grain]/df_SV_2022.vel_mean[grain])**2)/np.sqrt(df_SV.vels.size)
        # White, slightly larger marker first so the coloured marker gets a halo.
        ax1.errorbar(df_shape_2022.CSF[grain], (df_SV_VES_2022.ws_sphere[grain]/df_SV_2022.vel_mean[grain])**2, c='w', fmt=marker[grain_in], markersize=size[grain_in]+1)
        ax1.errorbar(df_shape_2022.CSF[grain], (df_SV_VES_2022.ws_sphere[grain]/df_SV_2022.vel_mean[grain])**2, xerr=df_shape_2022.dCSF[grain], yerr=yerr, ecolor=sns.xkcd_rgb['grey'], elinewidth=.9, capthick=.9, capsize=1, color=color[grain_in], markeredgecolor='w', markeredgewidth=1,  fmt=marker[grain_in], markersize=size[grain_in], label=label[grain_in])
        
    # Ellipse helpers; a and b are the two axis lengths.
    # Perimeter via the complete elliptic integral of the second kind,
    # ss.ellipe(m) with parameter m = 1 - b**2/a**2.
    def ellipse_perm(a,b): return 4*a*ss.ellipe(1.0 - b**2/a**2)  
    def ellipse_area(a,b): return np.pi * a * b
    # Circularity 4*pi*Area/Perimeter**2 (1 for a circle). It is scale-invariant,
    # so semi-axes and full axes give the same value.
    def ellipse_circ(a,b): return 4*np.pi*ellipse_area(a,b)/ellipse_perm(a,b)**2

    # Circularity (not Corey shape factor) of each 2022 grain, treating its outline
    # as an ellipse with axes df_shape_2022.C and .A (presumably the shortest and
    # longest grain axes -- confirm).
    Circ2 = {grains_new[grain_in]: ellipse_circ(df_shape_2022.C[grains_new[grain_in]], df_shape_2022.A[grains_new[grain_in]]) for grain_in in grains_2}
    for grain_in in grains_2:
        grain = grains_new[grain_in]
        dAoR = 1  # NOTE(review): assigned but not used below.
        dCirc2 = 0.02  # same circularity uncertainty for every 2022 grain
        # mu_s = tan(mean angle). NOTE(review): the y error bar is tan(std_angle),
        # not the propagated sec(mean)**2 * std -- confirm this is intended.
        ax2.errorbar(Circ2[grain], np.tan(np.radians(df_AoR_2022.mean_angle[grain])), c='w', fmt=marker[grain_in], markersize=size[grain_in]*0.8+0.5)
        ax2.errorbar(Circ2[grain], np.tan(np.radians(df_AoR_2022.mean_angle[grain])), xerr=dCirc2, yerr=np.tan(np.radians(df_AoR_2022.std_angle[grain])), ecolor=sns.xkcd_rgb['grey'], elinewidth=.9, capthick=.9, capsize=1, c=color[grain_in], markeredgecolor='w', markeredgewidth=1, fmt=marker[grain_in], markersize=size[grain_in]*0.8)


# Circularity S_c of each flume grain type (x-position of its friction point on
# ax2) and the matching x error bar.
Circ = dict(
    ng=0.816,
    ls=0.823,
    gg=0.812,
    oc=np.mean([0.82, 0.98]),  # midpoint of the two listed values
    gb=1.,                     # circularity of a circular outline
)
dCirc = dict(
    ng=0.016,
    ls=0.0093,
    gg=0.021,
    # NOTE(review): this spread uses [0.2, 0.3], while Circ['oc'] averages
    # [0.82, 0.98] -- confirm the intended values.
    oc=np.std([0.2, 0.3]),
    gb=0,
)


# plot flume shape data
# Every point is drawn twice: a slightly larger white marker (a halo) first,
# then the coloured marker with its error bars on top.
for g in grains:
    # ax2: mu_s = tan(mean angle) against circularity.
    mu_s = np.tan(np.radians(df_AoR.mean_angle[g]))
    ax2.errorbar(Circ[g], mu_s, c='w', fmt=marker[g], markersize=size[g]+1)
    ax2.errorbar(Circ[g], mu_s, xerr=dCirc[g], yerr=np.tan(np.radians(df_AoR.std_angle[g])),
                 ecolor=sns.xkcd_rgb['grey'], elinewidth=.9, capthick=.9, capsize=1, c=color[g],
                 markeredgecolor='w', markeredgewidth=1, fmt=marker[g], markersize=size[g])

for g in grains:
    # ax1: drag ratio (ws_sphere / vel_mean)**2 against Corey shape factor; the
    # relative errors of both velocities are propagated into the y error bar.
    ws_sph = df_SV_VES.ws_sphere[g]
    v_mean = df_SV.vel_mean[g]
    drag_ratio = (ws_sph/v_mean)**2
    drag_err = 2*drag_ratio*np.sqrt((df_SV_VES.dws_sphere[g]/ws_sph)**2 + (df_SV.dvel_mean[g]/v_mean)**2)
    ax1.errorbar(df_shape.CSF[g], drag_ratio, c='w', fmt=marker[g], markersize=size[g]+1)
    ax1.errorbar(df_shape.CSF[g], drag_ratio, xerr=df_shape.dCSF[g], yerr=drag_err,
                 ecolor=sns.xkcd_rgb['grey'], elinewidth=.9, capthick=.9, capsize=1, color=color[g],
                 markeredgecolor='w', markeredgewidth=1,  fmt=marker[g], markersize=size[g], label=label[g])

# More legend columns (smaller font) when the 2022 grains add extra entries.
legend_opts = {'ncol': 3, 'fontsize': 12} if plot_xtra_grains_flag is True else {'ncol': 1, 'fontsize': 13}
ax1.legend(loc=3, **legend_opts)

plt.tight_layout()
# The output file name records whether the 2022 extra grains are included.
out_name = 'Figure_1_xtra_grains.png' if plot_xtra_grains_flag else 'Figure_1.png'
plt.savefig(out_name, dpi=300)


# %%
